{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cb495041",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# 门控循环单元（GRU）\n",
    ":label:`sec_gru`\n",
    "\n",
    "在 :numref:`sec_bptt`中，\n",
    "我们讨论了如何在循环神经网络中计算梯度，\n",
    "以及矩阵连续乘积可以导致梯度消失或梯度爆炸的问题。\n",
    "下面我们简单思考一下这种梯度异常在实践中的意义：\n",
    "\n",
    "* 我们可能会遇到这样的情况：早期观测值对预测所有未来观测值具有非常重要的意义。\n",
    "  考虑一个极端情况，其中第一个观测值包含一个校验和，\n",
    "  目标是在序列的末尾辨别校验和是否正确。\n",
    "  在这种情况下，第一个词元的影响至关重要。\n",
    "  我们希望有某些机制能够在一个记忆元里存储重要的早期信息。\n",
    "  如果没有这样的机制，我们将不得不给这个观测值指定一个非常大的梯度，\n",
    "  因为它会影响所有后续的观测值。\n",
    "* 我们可能会遇到这样的情况：一些词元没有相关的观测值。\n",
    "  例如，在对网页内容进行情感分析时，\n",
    "  可能有一些辅助HTML代码与网页传达的情绪无关。\n",
    "  我们希望有一些机制来*跳过*隐状态表示中的此类词元。\n",
    "* 我们可能会遇到这样的情况：序列的各个部分之间存在逻辑中断。\n",
    "  例如，书的章节之间可能会有过渡存在，\n",
    "  或者证券的熊市和牛市之间可能会有过渡存在。\n",
    "  在这种情况下，最好有一种方法来*重置*我们的内部状态表示。\n",
    "\n",
    "在学术界已经提出了许多方法来解决这类问题。\n",
    "其中最早的方法是\"长短期记忆\"（long-short-term memory，LSTM）\n",
    " :cite:`Hochreiter.Schmidhuber.1997`，\n",
    "我们将在 :numref:`sec_lstm`中讨论。\n",
    "门控循环单元（gated recurrent unit，GRU）\n",
    " :cite:`Cho.Van-Merrienboer.Bahdanau.ea.2014`\n",
    "是一个稍微简化的变体，通常能够提供同等的效果，\n",
    "并且计算 :cite:`Chung.Gulcehre.Cho.ea.2014`的速度明显更快。\n",
    "由于门控循环单元更简单，我们从它开始解读。\n",
    "\n",
    "## 门控隐状态\n",
    "\n",
    "门控循环单元与普通的循环神经网络之间的关键区别在于：\n",
    "前者支持隐状态的门控。\n",
    "这意味着模型有专门的机制来确定应该何时更新隐状态，\n",
    "以及应该何时重置隐状态。\n",
    "这些机制是可学习的，并且能够解决了上面列出的问题。\n",
    "例如，如果第一个词元非常重要，\n",
    "模型将学会在第一次观测之后不更新隐状态。\n",
    "同样，模型也可以学会跳过不相关的临时观测。\n",
    "最后，模型还将学会在需要的时候重置隐状态。\n",
    "下面我们将详细讨论各类门控。\n",
    "\n",
    "### 重置门和更新门\n",
    "\n",
    "我们首先介绍*重置门*（reset gate）和*更新门*（update gate）。\n",
    "我们把它们设计成$(0, 1)$区间中的向量，\n",
    "这样我们就可以进行凸组合。\n",
    "重置门允许我们控制“可能还想记住”的过去状态的数量；\n",
    "更新门将允许我们控制新状态中有多少个是旧状态的副本。\n",
    "\n",
    "我们从构造这些门控开始。 :numref:`fig_gru_1`\n",
    "描述了门控循环单元中的重置门和更新门的输入，\n",
    "输入是由当前时间步的输入和前一时间步的隐状态给出。\n",
    "两个门的输出是由使用sigmoid激活函数的两个全连接层给出。\n",
    "\n",
    "![在门控循环单元模型中计算重置门和更新门](../img/gru-1.svg)\n",
    ":label:`fig_gru_1`\n",
    "\n",
    "我们来看一下门控循环单元的数学表达。\n",
    "对于给定的时间步$t$，假设输入是一个小批量\n",
    "$\\mathbf{X}_t \\in \\mathbb{R}^{n \\times d}$\n",
    "（样本个数$n$，输入个数$d$），\n",
    "上一个时间步的隐状态是\n",
    "$\\mathbf{H}_{t-1} \\in \\mathbb{R}^{n \\times h}$\n",
    "（隐藏单元个数$h$）。\n",
    "那么，重置门$\\mathbf{R}_t \\in \\mathbb{R}^{n \\times h}$和\n",
    "更新门$\\mathbf{Z}_t \\in \\mathbb{R}^{n \\times h}$的计算如下所示：\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\mathbf{R}_t = \\sigma(\\mathbf{X}_t \\mathbf{W}_{xr} + \\mathbf{H}_{t-1} \\mathbf{W}_{hr} + \\mathbf{b}_r),\\\\\n",
    "\\mathbf{Z}_t = \\sigma(\\mathbf{X}_t \\mathbf{W}_{xz} + \\mathbf{H}_{t-1} \\mathbf{W}_{hz} + \\mathbf{b}_z),\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "其中$\\mathbf{W}_{xr}, \\mathbf{W}_{xz} \\in \\mathbb{R}^{d \\times h}$\n",
    "和$\\mathbf{W}_{hr}, \\mathbf{W}_{hz} \\in \\mathbb{R}^{h \\times h}$是权重参数，\n",
    "$\\mathbf{b}_r, \\mathbf{b}_z \\in \\mathbb{R}^{1 \\times h}$是偏置参数。\n",
    "请注意，在求和过程中会触发广播机制\n",
    "（请参阅 :numref:`subsec_broadcasting`）。\n",
    "我们使用sigmoid函数（如 :numref:`sec_mlp`中介绍的）\n",
    "将输入值转换到区间$(0, 1)$。\n",
    "\n",
    "### 候选隐状态\n",
    "\n",
    "接下来，让我们将重置门$\\mathbf{R}_t$\n",
    "与 :eqref:`rnn_h_with_state`\n",
    "中的常规隐状态更新机制集成，\n",
    "得到在时间步$t$的*候选隐状态*（candidate hidden state）\n",
    "$\\tilde{\\mathbf{H}}_t \\in \\mathbb{R}^{n \\times h}$。\n",
    "\n",
    "$$\\tilde{\\mathbf{H}}_t = \\tanh(\\mathbf{X}_t \\mathbf{W}_{xh} + \\left(\\mathbf{R}_t \\odot \\mathbf{H}_{t-1}\\right) \\mathbf{W}_{hh} + \\mathbf{b}_h),$$\n",
    ":eqlabel:`gru_tilde_H`\n",
    "\n",
    "其中$\\mathbf{W}_{xh} \\in \\mathbb{R}^{d \\times h}$\n",
    "和$\\mathbf{W}_{hh} \\in \\mathbb{R}^{h \\times h}$是权重参数，\n",
    "$\\mathbf{b}_h \\in \\mathbb{R}^{1 \\times h}$是偏置项，\n",
    "符号$\\odot$是Hadamard积（按元素乘积）运算符。\n",
    "在这里，我们使用tanh非线性激活函数来确保候选隐状态中的值保持在区间$(-1, 1)$中。\n",
    "\n",
    "与 :eqref:`rnn_h_with_state`相比，\n",
    " :eqref:`gru_tilde_H`中的$\\mathbf{R}_t$和$\\mathbf{H}_{t-1}$\n",
    "的元素相乘可以减少以往状态的影响。\n",
    "每当重置门$\\mathbf{R}_t$中的项接近$1$时，\n",
    "我们恢复一个如 :eqref:`rnn_h_with_state`中的普通的循环神经网络。\n",
    "对于重置门$\\mathbf{R}_t$中所有接近$0$的项，\n",
    "候选隐状态是以$\\mathbf{X}_t$作为输入的多层感知机的结果。\n",
    "因此，任何预先存在的隐状态都会被*重置*为默认值。\n",
    "\n",
    " :numref:`fig_gru_2`说明了应用重置门之后的计算流程。\n",
    "\n",
    "![在门控循环单元模型中计算候选隐状态](../img/gru-2.svg)\n",
    ":label:`fig_gru_2`\n",
    "\n",
    "### 隐状态\n",
    "\n",
    "上述的计算结果只是候选隐状态，我们仍然需要结合更新门$\\mathbf{Z}_t$的效果。\n",
    "这一步确定新的隐状态$\\mathbf{H}_t \\in \\mathbb{R}^{n \\times h}$\n",
    "在多大程度上来自旧的状态$\\mathbf{H}_{t-1}$和\n",
    "新的候选状态$\\tilde{\\mathbf{H}}_t$。\n",
    "更新门$\\mathbf{Z}_t$仅需要在\n",
    "$\\mathbf{H}_{t-1}$和$\\tilde{\\mathbf{H}}_t$\n",
    "之间进行按元素的凸组合就可以实现这个目标。\n",
    "这就得出了门控循环单元的最终更新公式：\n",
    "\n",
    "$$\\mathbf{H}_t = \\mathbf{Z}_t \\odot \\mathbf{H}_{t-1}  + (1 - \\mathbf{Z}_t) \\odot \\tilde{\\mathbf{H}}_t.$$\n",
    "\n",
    "每当更新门$\\mathbf{Z}_t$接近$1$时，模型就倾向只保留旧状态。\n",
    "此时，来自$\\mathbf{X}_t$的信息基本上被忽略，\n",
    "从而有效地跳过了依赖链条中的时间步$t$。\n",
    "相反，当$\\mathbf{Z}_t$接近$0$时，\n",
    "新的隐状态$\\mathbf{H}_t$就会接近候选隐状态$\\tilde{\\mathbf{H}}_t$。\n",
    "这些设计可以帮助我们处理循环神经网络中的梯度消失问题，\n",
    "并更好地捕获时间步距离很长的序列的依赖关系。\n",
    "例如，如果整个子序列的所有时间步的更新门都接近于$1$，\n",
    "则无论序列的长度如何，在序列起始时间步的旧隐状态都将很容易保留并传递到序列结束。\n",
    "\n",
    " :numref:`fig_gru_3`说明了更新门起作用后的计算流。\n",
    "\n",
    "![计算门控循环单元模型中的隐状态](../img/gru-3.svg)\n",
    ":label:`fig_gru_3`\n",
    "\n",
    "总之，门控循环单元具有以下两个显著特征：\n",
    "\n",
    "* 重置门有助于捕获序列中的短期依赖关系；\n",
    "* 更新门有助于捕获序列中的长期依赖关系。\n",
    "\n",
    "## 从零开始实现\n",
    "\n",
    "为了更好地理解门控循环单元模型，我们从零开始实现它。\n",
    "首先，我们读取 :numref:`sec_rnn_scratch`中使用的时间机器数据集：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a3bbf4ae",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:25:22.352712Z",
     "iopub.status.busy": "2023-08-18T07:25:22.351930Z",
     "iopub.status.idle": "2023-08-18T07:25:25.503017Z",
     "shell.execute_reply": "2023-08-18T07:25:25.502124Z"
    },
    "origin_pos": 2,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l\n",
    "\n",
    "batch_size, num_steps = 32, 35\n",
    "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11663bec",
   "metadata": {
    "origin_pos": 5
   },
   "source": [
    "### [**初始化模型参数**]\n",
    "\n",
    "下一步是初始化模型参数。\n",
    "我们从标准差为$0.01$的高斯分布中提取权重，\n",
    "并将偏置项设为$0$，超参数`num_hiddens`定义隐藏单元的数量，\n",
    "实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8664bb4a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:25:25.507494Z",
     "iopub.status.busy": "2023-08-18T07:25:25.506826Z",
     "iopub.status.idle": "2023-08-18T07:25:25.513659Z",
     "shell.execute_reply": "2023-08-18T07:25:25.512868Z"
    },
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def get_params(vocab_size, num_hiddens, device):\n",
    "    num_inputs = num_outputs = vocab_size\n",
    "\n",
    "    def normal(shape):\n",
    "        return torch.randn(size=shape, device=device)*0.01\n",
    "\n",
    "    def three():\n",
    "        return (normal((num_inputs, num_hiddens)),\n",
    "                normal((num_hiddens, num_hiddens)),\n",
    "                torch.zeros(num_hiddens, device=device))\n",
    "\n",
    "    W_xz, W_hz, b_z = three()  # 更新门参数\n",
    "    W_xr, W_hr, b_r = three()  # 重置门参数\n",
    "    W_xh, W_hh, b_h = three()  # 候选隐状态参数\n",
    "    # 输出层参数\n",
    "    W_hq = normal((num_hiddens, num_outputs))\n",
    "    b_q = torch.zeros(num_outputs, device=device)\n",
    "    # 附加梯度\n",
    "    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]\n",
    "    for param in params:\n",
    "        param.requires_grad_(True)\n",
    "    return params"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c819ab4f",
   "metadata": {
    "origin_pos": 10
   },
   "source": [
    "### 定义模型\n",
    "\n",
    "现在我们将[**定义隐状态的初始化函数**]`init_gru_state`。\n",
    "与 :numref:`sec_rnn_scratch`中定义的`init_rnn_state`函数一样，\n",
    "此函数返回一个形状为（批量大小，隐藏单元个数）的张量，张量的值全部为零。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc77ddd0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:25:25.517215Z",
     "iopub.status.busy": "2023-08-18T07:25:25.516638Z",
     "iopub.status.idle": "2023-08-18T07:25:25.520532Z",
     "shell.execute_reply": "2023-08-18T07:25:25.519776Z"
    },
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def init_gru_state(batch_size, num_hiddens, device):\n",
    "    return (torch.zeros((batch_size, num_hiddens), device=device), )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dbb29722",
   "metadata": {
    "origin_pos": 15
   },
   "source": [
    "现在我们准备[**定义门控循环单元模型**]，\n",
    "模型的架构与基本的循环神经网络单元是相同的，\n",
    "只是权重更新公式更为复杂。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "648faa2f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:25:25.523984Z",
     "iopub.status.busy": "2023-08-18T07:25:25.523454Z",
     "iopub.status.idle": "2023-08-18T07:25:25.529513Z",
     "shell.execute_reply": "2023-08-18T07:25:25.528547Z"
    },
    "origin_pos": 17,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def gru(inputs, state, params):\n",
    "    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params\n",
    "    H, = state\n",
    "    outputs = []\n",
    "    for X in inputs:\n",
    "        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)\n",
    "        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)\n",
    "        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)\n",
    "        H = Z * H + (1 - Z) * H_tilda\n",
    "        Y = H @ W_hq + b_q\n",
    "        outputs.append(Y)\n",
    "    return torch.cat(outputs, dim=0), (H,)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "024bb39d",
   "metadata": {
    "origin_pos": 20
   },
   "source": [
    "### [**训练**]与预测\n",
    "\n",
    "训练和预测的工作方式与 :numref:`sec_rnn_scratch`完全相同。\n",
    "训练结束后，我们分别打印输出训练集的困惑度，\n",
    "以及前缀“time traveler”和“traveler”的预测序列上的困惑度。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6763946e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:25:25.534734Z",
     "iopub.status.busy": "2023-08-18T07:25:25.534372Z",
     "iopub.status.idle": "2023-08-18T07:29:05.217170Z",
     "shell.execute_reply": "2023-08-18T07:29:05.216294Z"
    },
    "origin_pos": 21,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "perplexity 1.1, 19911.5 tokens/sec on cuda:0\n",
      "time traveller firenis i heidfile sook at i jomer and sugard are\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "travelleryou can show black is white by argument said filby\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"252.646875pt\" height=\"180.65625pt\" viewBox=\"0 0 252.646875 180.65625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
       " <metadata>\n",
       "  <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
       "   <cc:Work>\n",
       "    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
       "    <dc:date>2023-08-18T07:29:05.182244</dc:date>\n",
       "    <dc:format>image/svg+xml</dc:format>\n",
       "    <dc:creator>\n",
       "     <cc:Agent>\n",
       "      <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>\n",
       "     </cc:Agent>\n",
       "    </dc:creator>\n",
       "   </cc:Work>\n",
       "  </rdf:RDF>\n",
       " </metadata>\n",
       " <defs>\n",
       "  <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 180.65625 \n",
       "L 252.646875 180.65625 \n",
       "L 252.646875 0 \n",
       "L 0 0 \n",
       "L 0 180.65625 \n",
       "z\n",
       "\" style=\"fill: none\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 40.603125 143.1 \n",
       "L 235.903125 143.1 \n",
       "L 235.903125 7.2 \n",
       "L 40.603125 7.2 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <path d=\"M 76.474554 143.1 \n",
       "L 76.474554 7.2 \n",
       "\" clip-path=\"url(#p6e2461c57e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_2\">\n",
       "      <defs>\n",
       "       <path id=\"md9461def62\" d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#md9461def62\" x=\"76.474554\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 100 -->\n",
       "      <g transform=\"translate(66.930804 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
       "L 1825 531 \n",
       "L 1825 4091 \n",
       "L 703 3866 \n",
       "L 703 4441 \n",
       "L 1819 4666 \n",
       "L 2450 4666 \n",
       "L 2450 531 \n",
       "L 3481 531 \n",
       "L 3481 0 \n",
       "L 794 0 \n",
       "L 794 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "        <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
       "Q 1547 4250 1301 3770 \n",
       "Q 1056 3291 1056 2328 \n",
       "Q 1056 1369 1301 889 \n",
       "Q 1547 409 2034 409 \n",
       "Q 2525 409 2770 889 \n",
       "Q 3016 1369 3016 2328 \n",
       "Q 3016 3291 2770 3770 \n",
       "Q 2525 4250 2034 4250 \n",
       "z\n",
       "M 2034 4750 \n",
       "Q 2819 4750 3233 4129 \n",
       "Q 3647 3509 3647 2328 \n",
       "Q 3647 1150 3233 529 \n",
       "Q 2819 -91 2034 -91 \n",
       "Q 1250 -91 836 529 \n",
       "Q 422 1150 422 2328 \n",
       "Q 422 3509 836 4129 \n",
       "Q 1250 4750 2034 4750 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <path d=\"M 116.331696 143.1 \n",
       "L 116.331696 7.2 \n",
       "\" clip-path=\"url(#p6e2461c57e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#md9461def62\" x=\"116.331696\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 200 -->\n",
       "      <g transform=\"translate(106.787946 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
       "L 3431 531 \n",
       "L 3431 0 \n",
       "L 469 0 \n",
       "L 469 531 \n",
       "Q 828 903 1448 1529 \n",
       "Q 2069 2156 2228 2338 \n",
       "Q 2531 2678 2651 2914 \n",
       "Q 2772 3150 2772 3378 \n",
       "Q 2772 3750 2511 3984 \n",
       "Q 2250 4219 1831 4219 \n",
       "Q 1534 4219 1204 4116 \n",
       "Q 875 4013 500 3803 \n",
       "L 500 4441 \n",
       "Q 881 4594 1212 4672 \n",
       "Q 1544 4750 1819 4750 \n",
       "Q 2544 4750 2975 4387 \n",
       "Q 3406 4025 3406 3419 \n",
       "Q 3406 3131 3298 2873 \n",
       "Q 3191 2616 2906 2266 \n",
       "Q 2828 2175 2409 1742 \n",
       "Q 1991 1309 1228 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-32\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <path d=\"M 156.188839 143.1 \n",
       "L 156.188839 7.2 \n",
       "\" clip-path=\"url(#p6e2461c57e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#md9461def62\" x=\"156.188839\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 300 -->\n",
       "      <g transform=\"translate(146.645089 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-33\" d=\"M 2597 2516 \n",
       "Q 3050 2419 3304 2112 \n",
       "Q 3559 1806 3559 1356 \n",
       "Q 3559 666 3084 287 \n",
       "Q 2609 -91 1734 -91 \n",
       "Q 1441 -91 1130 -33 \n",
       "Q 819 25 488 141 \n",
       "L 488 750 \n",
       "Q 750 597 1062 519 \n",
       "Q 1375 441 1716 441 \n",
       "Q 2309 441 2620 675 \n",
       "Q 2931 909 2931 1356 \n",
       "Q 2931 1769 2642 2001 \n",
       "Q 2353 2234 1838 2234 \n",
       "L 1294 2234 \n",
       "L 1294 2753 \n",
       "L 1863 2753 \n",
       "Q 2328 2753 2575 2939 \n",
       "Q 2822 3125 2822 3475 \n",
       "Q 2822 3834 2567 4026 \n",
       "Q 2313 4219 1838 4219 \n",
       "Q 1578 4219 1281 4162 \n",
       "Q 984 4106 628 3988 \n",
       "L 628 4550 \n",
       "Q 988 4650 1302 4700 \n",
       "Q 1616 4750 1894 4750 \n",
       "Q 2613 4750 3031 4423 \n",
       "Q 3450 4097 3450 3541 \n",
       "Q 3450 3153 3228 2886 \n",
       "Q 3006 2619 2597 2516 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-33\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_4\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <path d=\"M 196.045982 143.1 \n",
       "L 196.045982 7.2 \n",
       "\" clip-path=\"url(#p6e2461c57e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#md9461def62\" x=\"196.045982\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 400 -->\n",
       "      <g transform=\"translate(186.502232 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
       "L 825 1625 \n",
       "L 2419 1625 \n",
       "L 2419 4116 \n",
       "z\n",
       "M 2253 4666 \n",
       "L 3047 4666 \n",
       "L 3047 1625 \n",
       "L 3713 1625 \n",
       "L 3713 1100 \n",
       "L 3047 1100 \n",
       "L 3047 0 \n",
       "L 2419 0 \n",
       "L 2419 1100 \n",
       "L 313 1100 \n",
       "L 313 1709 \n",
       "L 2253 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-34\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_5\">\n",
       "     <g id=\"line2d_9\">\n",
       "      <path d=\"M 235.903125 143.1 \n",
       "L 235.903125 7.2 \n",
       "\" clip-path=\"url(#p6e2461c57e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_10\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#md9461def62\" x=\"235.903125\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 500 -->\n",
       "      <g transform=\"translate(226.359375 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
       "L 3169 4666 \n",
       "L 3169 4134 \n",
       "L 1269 4134 \n",
       "L 1269 2991 \n",
       "Q 1406 3038 1543 3061 \n",
       "Q 1681 3084 1819 3084 \n",
       "Q 2600 3084 3056 2656 \n",
       "Q 3513 2228 3513 1497 \n",
       "Q 3513 744 3044 326 \n",
       "Q 2575 -91 1722 -91 \n",
       "Q 1428 -91 1123 -41 \n",
       "Q 819 9 494 109 \n",
       "L 494 744 \n",
       "Q 775 591 1075 516 \n",
       "Q 1375 441 1709 441 \n",
       "Q 2250 441 2565 725 \n",
       "Q 2881 1009 2881 1497 \n",
       "Q 2881 1984 2565 2268 \n",
       "Q 2250 2553 1709 2553 \n",
       "Q 1456 2553 1204 2497 \n",
       "Q 953 2441 691 2322 \n",
       "L 691 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-35\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- epoch -->\n",
       "     <g transform=\"translate(123.025 171.376563)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
       "L 3597 1613 \n",
       "L 953 1613 \n",
       "Q 991 1019 1311 708 \n",
       "Q 1631 397 2203 397 \n",
       "Q 2534 397 2845 478 \n",
       "Q 3156 559 3463 722 \n",
       "L 3463 178 \n",
       "Q 3153 47 2828 -22 \n",
       "Q 2503 -91 2169 -91 \n",
       "Q 1331 -91 842 396 \n",
       "Q 353 884 353 1716 \n",
       "Q 353 2575 817 3079 \n",
       "Q 1281 3584 2069 3584 \n",
       "Q 2775 3584 3186 3129 \n",
       "Q 3597 2675 3597 1894 \n",
       "z\n",
       "M 3022 2063 \n",
       "Q 3016 2534 2758 2815 \n",
       "Q 2500 3097 2075 3097 \n",
       "Q 1594 3097 1305 2825 \n",
       "Q 1016 2553 972 2059 \n",
       "L 3022 2063 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-70\" d=\"M 1159 525 \n",
       "L 1159 -1331 \n",
       "L 581 -1331 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2969 \n",
       "Q 1341 3281 1617 3432 \n",
       "Q 1894 3584 2278 3584 \n",
       "Q 2916 3584 3314 3078 \n",
       "Q 3713 2572 3713 1747 \n",
       "Q 3713 922 3314 415 \n",
       "Q 2916 -91 2278 -91 \n",
       "Q 1894 -91 1617 61 \n",
       "Q 1341 213 1159 525 \n",
       "z\n",
       "M 3116 1747 \n",
       "Q 3116 2381 2855 2742 \n",
       "Q 2594 3103 2138 3103 \n",
       "Q 1681 3103 1420 2742 \n",
       "Q 1159 2381 1159 1747 \n",
       "Q 1159 1113 1420 752 \n",
       "Q 1681 391 2138 391 \n",
       "Q 2594 391 2855 752 \n",
       "Q 3116 1113 3116 1747 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
       "Q 1497 3097 1228 2736 \n",
       "Q 959 2375 959 1747 \n",
       "Q 959 1119 1226 758 \n",
       "Q 1494 397 1959 397 \n",
       "Q 2419 397 2687 759 \n",
       "Q 2956 1122 2956 1747 \n",
       "Q 2956 2369 2687 2733 \n",
       "Q 2419 3097 1959 3097 \n",
       "z\n",
       "M 1959 3584 \n",
       "Q 2709 3584 3137 3096 \n",
       "Q 3566 2609 3566 1747 \n",
       "Q 3566 888 3137 398 \n",
       "Q 2709 -91 1959 -91 \n",
       "Q 1206 -91 779 398 \n",
       "Q 353 888 353 1747 \n",
       "Q 353 2609 779 3096 \n",
       "Q 1206 3584 1959 3584 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-63\" d=\"M 3122 3366 \n",
       "L 3122 2828 \n",
       "Q 2878 2963 2633 3030 \n",
       "Q 2388 3097 2138 3097 \n",
       "Q 1578 3097 1268 2742 \n",
       "Q 959 2388 959 1747 \n",
       "Q 959 1106 1268 751 \n",
       "Q 1578 397 2138 397 \n",
       "Q 2388 397 2633 464 \n",
       "Q 2878 531 3122 666 \n",
       "L 3122 134 \n",
       "Q 2881 22 2623 -34 \n",
       "Q 2366 -91 2075 -91 \n",
       "Q 1284 -91 818 406 \n",
       "Q 353 903 353 1747 \n",
       "Q 353 2603 823 3093 \n",
       "Q 1294 3584 2113 3584 \n",
       "Q 2378 3584 2631 3529 \n",
       "Q 2884 3475 3122 3366 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-68\" d=\"M 3513 2113 \n",
       "L 3513 0 \n",
       "L 2938 0 \n",
       "L 2938 2094 \n",
       "Q 2938 2591 2744 2837 \n",
       "Q 2550 3084 2163 3084 \n",
       "Q 1697 3084 1428 2787 \n",
       "Q 1159 2491 1159 1978 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 4863 \n",
       "L 1159 4863 \n",
       "L 1159 2956 \n",
       "Q 1366 3272 1645 3428 \n",
       "Q 1925 3584 2291 3584 \n",
       "Q 2894 3584 3203 3211 \n",
       "Q 3513 2838 3513 2113 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-65\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-70\" x=\"61.523438\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6f\" x=\"125\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-63\" x=\"186.181641\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-68\" x=\"241.162109\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_11\">\n",
       "      <path d=\"M 40.603125 106.757691 \n",
       "L 235.903125 106.757691 \n",
       "\" clip-path=\"url(#p6e2461c57e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_12\">\n",
       "      <defs>\n",
       "       <path id=\"m769a88ef7f\" d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m769a88ef7f\" x=\"40.603125\" y=\"106.757691\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 5 -->\n",
       "      <g transform=\"translate(27.240625 110.556909)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_13\">\n",
       "      <path d=\"M 40.603125 68.524691 \n",
       "L 235.903125 68.524691 \n",
       "\" clip-path=\"url(#p6e2461c57e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_14\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m769a88ef7f\" x=\"40.603125\" y=\"68.524691\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 10 -->\n",
       "      <g transform=\"translate(20.878125 72.323909)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_15\">\n",
       "      <path d=\"M 40.603125 30.29169 \n",
       "L 235.903125 30.29169 \n",
       "\" clip-path=\"url(#p6e2461c57e)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_16\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m769a88ef7f\" x=\"40.603125\" y=\"30.29169\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 15 -->\n",
       "      <g transform=\"translate(20.878125 34.090909)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"63.623047\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_10\">\n",
       "     <!-- perplexity -->\n",
       "     <g transform=\"translate(14.798437 100.276562)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
       "Q 2534 3019 2420 3045 \n",
       "Q 2306 3072 2169 3072 \n",
       "Q 1681 3072 1420 2755 \n",
       "Q 1159 2438 1159 1844 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1341 3275 1631 3429 \n",
       "Q 1922 3584 2338 3584 \n",
       "Q 2397 3584 2469 3576 \n",
       "Q 2541 3569 2628 3553 \n",
       "L 2631 2963 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-78\" d=\"M 3513 3500 \n",
       "L 2247 1797 \n",
       "L 3578 0 \n",
       "L 2900 0 \n",
       "L 1881 1375 \n",
       "L 863 0 \n",
       "L 184 0 \n",
       "L 1544 1831 \n",
       "L 300 3500 \n",
       "L 978 3500 \n",
       "L 1906 2253 \n",
       "L 2834 3500 \n",
       "L 3513 3500 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-69\" d=\"M 603 3500 \n",
       "L 1178 3500 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 3500 \n",
       "z\n",
       "M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 4134 \n",
       "L 603 4134 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-74\" d=\"M 1172 4494 \n",
       "L 1172 3500 \n",
       "L 2356 3500 \n",
       "L 2356 3053 \n",
       "L 1172 3053 \n",
       "L 1172 1153 \n",
       "Q 1172 725 1289 603 \n",
       "Q 1406 481 1766 481 \n",
       "L 2356 481 \n",
       "L 2356 0 \n",
       "L 1766 0 \n",
       "Q 1100 0 847 248 \n",
       "Q 594 497 594 1153 \n",
       "L 594 3053 \n",
       "L 172 3053 \n",
       "L 172 3500 \n",
       "L 594 3500 \n",
       "L 594 4494 \n",
       "L 1172 4494 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-79\" d=\"M 2059 -325 \n",
       "Q 1816 -950 1584 -1140 \n",
       "Q 1353 -1331 966 -1331 \n",
       "L 506 -1331 \n",
       "L 506 -850 \n",
       "L 844 -850 \n",
       "Q 1081 -850 1212 -737 \n",
       "Q 1344 -625 1503 -206 \n",
       "L 1606 56 \n",
       "L 191 3500 \n",
       "L 800 3500 \n",
       "L 1894 763 \n",
       "L 2988 3500 \n",
       "L 3597 3500 \n",
       "L 2059 -325 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-70\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"63.476562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"125\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-70\" x=\"166.113281\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6c\" x=\"229.589844\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"257.373047\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-78\" x=\"317.146484\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"376.326172\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-74\" x=\"404.109375\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-79\" x=\"443.318359\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"line2d_17\">\n",
       "    <path d=\"M 40.603125 13.377273 \n",
       "L 44.588839 25.637575 \n",
       "L 48.574554 47.126714 \n",
       "L 52.560268 55.729213 \n",
       "L 56.545982 62.18617 \n",
       "L 60.531696 65.683996 \n",
       "L 64.517411 69.981808 \n",
       "L 68.503125 70.353781 \n",
       "L 72.488839 74.58104 \n",
       "L 76.474554 77.313672 \n",
       "L 80.460268 79.576411 \n",
       "L 84.445982 81.704648 \n",
       "L 88.431696 82.872417 \n",
       "L 92.417411 85.559844 \n",
       "L 96.403125 87.509689 \n",
       "L 100.388839 88.931127 \n",
       "L 104.374554 91.718442 \n",
       "L 108.360268 93.665367 \n",
       "L 112.345982 95.725488 \n",
       "L 116.331696 97.295552 \n",
       "L 120.317411 100.005358 \n",
       "L 124.303125 101.85475 \n",
       "L 128.288839 104.596201 \n",
       "L 132.274554 106.382443 \n",
       "L 136.260268 108.970279 \n",
       "L 140.245982 111.209329 \n",
       "L 144.231696 113.566362 \n",
       "L 148.217411 116.070633 \n",
       "L 152.203125 119.264721 \n",
       "L 156.188839 122.333673 \n",
       "L 160.174554 124.422252 \n",
       "L 164.160268 126.890865 \n",
       "L 168.145982 128.406277 \n",
       "L 172.131696 130.614493 \n",
       "L 176.117411 133.039091 \n",
       "L 180.103125 133.812798 \n",
       "L 184.088839 133.844135 \n",
       "L 188.074554 135.552256 \n",
       "L 192.060268 135.203639 \n",
       "L 196.045982 136.122186 \n",
       "L 200.031696 136.212058 \n",
       "L 204.017411 136.471822 \n",
       "L 208.003125 136.590367 \n",
       "L 211.988839 136.010159 \n",
       "L 215.974554 136.620047 \n",
       "L 219.960268 136.749273 \n",
       "L 223.945982 136.775328 \n",
       "L 227.931696 136.738094 \n",
       "L 231.917411 136.831766 \n",
       "L 235.903125 136.922727 \n",
       "\" clip-path=\"url(#p6e2461c57e)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 40.603125 143.1 \n",
       "L 40.603125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 235.903125 143.1 \n",
       "L 235.903125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 40.603125 143.1 \n",
       "L 235.903125 143.1 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 40.603125 7.2 \n",
       "L 235.903125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"legend_1\">\n",
       "    <g id=\"patch_7\">\n",
       "     <path d=\"M 173.628125 29.878125 \n",
       "L 228.903125 29.878125 \n",
       "Q 230.903125 29.878125 230.903125 27.878125 \n",
       "L 230.903125 14.2 \n",
       "Q 230.903125 12.2 228.903125 12.2 \n",
       "L 173.628125 12.2 \n",
       "Q 171.628125 12.2 171.628125 14.2 \n",
       "L 171.628125 27.878125 \n",
       "Q 171.628125 29.878125 173.628125 29.878125 \n",
       "z\n",
       "\" style=\"fill: #ffffff; opacity: 0.8; stroke: #cccccc; stroke-linejoin: miter\"/>\n",
       "    </g>\n",
       "    <g id=\"line2d_18\">\n",
       "     <path d=\"M 175.628125 20.298437 \n",
       "L 185.628125 20.298437 \n",
       "L 195.628125 20.298437 \n",
       "\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
       "    </g>\n",
       "    <g id=\"text_11\">\n",
       "     <!-- train -->\n",
       "     <g transform=\"translate(203.628125 23.798437)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-61\" d=\"M 2194 1759 \n",
       "Q 1497 1759 1228 1600 \n",
       "Q 959 1441 959 1056 \n",
       "Q 959 750 1161 570 \n",
       "Q 1363 391 1709 391 \n",
       "Q 2188 391 2477 730 \n",
       "Q 2766 1069 2766 1631 \n",
       "L 2766 1759 \n",
       "L 2194 1759 \n",
       "z\n",
       "M 3341 1997 \n",
       "L 3341 0 \n",
       "L 2766 0 \n",
       "L 2766 531 \n",
       "Q 2569 213 2275 61 \n",
       "Q 1981 -91 1556 -91 \n",
       "Q 1019 -91 701 211 \n",
       "Q 384 513 384 1019 \n",
       "Q 384 1609 779 1909 \n",
       "Q 1175 2209 1959 2209 \n",
       "L 2766 2209 \n",
       "L 2766 2266 \n",
       "Q 2766 2663 2505 2880 \n",
       "Q 2244 3097 1772 3097 \n",
       "Q 1472 3097 1187 3025 \n",
       "Q 903 2953 641 2809 \n",
       "L 641 3341 \n",
       "Q 956 3463 1253 3523 \n",
       "Q 1550 3584 1831 3584 \n",
       "Q 2591 3584 2966 3190 \n",
       "Q 3341 2797 3341 1997 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6e\" d=\"M 3513 2113 \n",
       "L 3513 0 \n",
       "L 2938 0 \n",
       "L 2938 2094 \n",
       "Q 2938 2591 2744 2837 \n",
       "Q 2550 3084 2163 3084 \n",
       "Q 1697 3084 1428 2787 \n",
       "Q 1159 2491 1159 1978 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1366 3272 1645 3428 \n",
       "Q 1925 3584 2291 3584 \n",
       "Q 2894 3584 3203 3211 \n",
       "Q 3513 2838 3513 2113 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-74\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"39.208984\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-61\" x=\"80.322266\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"141.601562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6e\" x=\"169.384766\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"p6e2461c57e\">\n",
       "   <rect x=\"40.603125\" y=\"7.2\" width=\"195.3\" height=\"135.9\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()\n",
    "num_epochs, lr = 500, 1\n",
    "model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,\n",
    "                            init_gru_state, gru)\n",
    "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbac9ebd",
   "metadata": {
    "origin_pos": 24
   },
   "source": [
    "## [**简洁实现**]\n",
    "\n",
    "高级API包含了前文介绍的所有配置细节，\n",
    "所以我们可以直接实例化门控循环单元模型。\n",
    "这段代码的运行速度要快得多，\n",
    "因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6549d929",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T07:29:05.220877Z",
     "iopub.status.busy": "2023-08-18T07:29:05.220574Z",
     "iopub.status.idle": "2023-08-18T07:29:30.037933Z",
     "shell.execute_reply": "2023-08-18T07:29:30.036733Z"
    },
    "origin_pos": 26,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "perplexity 1.0, 109423.8 tokens/sec on cuda:0\n",
      "time travelleryou can show black is white by argument said filby\n",
      "traveller with a slight accession ofcheerfulness really thi\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<svg xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"252.646875pt\" height=\"180.65625pt\" viewBox=\"0 0 252.646875 180.65625\" xmlns=\"http://www.w3.org/2000/svg\" version=\"1.1\">\n",
       " <metadata>\n",
       "  <rdf:RDF xmlns:dc=\"http://purl.org/dc/elements/1.1/\" xmlns:cc=\"http://creativecommons.org/ns#\" xmlns:rdf=\"http://www.w3.org/1999/02/22-rdf-syntax-ns#\">\n",
       "   <cc:Work>\n",
       "    <dc:type rdf:resource=\"http://purl.org/dc/dcmitype/StillImage\"/>\n",
       "    <dc:date>2023-08-18T07:29:30.001099</dc:date>\n",
       "    <dc:format>image/svg+xml</dc:format>\n",
       "    <dc:creator>\n",
       "     <cc:Agent>\n",
       "      <dc:title>Matplotlib v3.5.1, https://matplotlib.org/</dc:title>\n",
       "     </cc:Agent>\n",
       "    </dc:creator>\n",
       "   </cc:Work>\n",
       "  </rdf:RDF>\n",
       " </metadata>\n",
       " <defs>\n",
       "  <style type=\"text/css\">*{stroke-linejoin: round; stroke-linecap: butt}</style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 180.65625 \n",
       "L 252.646875 180.65625 \n",
       "L 252.646875 0 \n",
       "L 0 0 \n",
       "L 0 180.65625 \n",
       "z\n",
       "\" style=\"fill: none\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 40.603125 143.1 \n",
       "L 235.903125 143.1 \n",
       "L 235.903125 7.2 \n",
       "L 40.603125 7.2 \n",
       "z\n",
       "\" style=\"fill: #ffffff\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <path d=\"M 76.474554 143.1 \n",
       "L 76.474554 7.2 \n",
       "\" clip-path=\"url(#p0c61bccd78)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_2\">\n",
       "      <defs>\n",
       "       <path id=\"m496c82adcf\" d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m496c82adcf\" x=\"76.474554\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 100 -->\n",
       "      <g transform=\"translate(66.930804 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-31\" d=\"M 794 531 \n",
       "L 1825 531 \n",
       "L 1825 4091 \n",
       "L 703 3866 \n",
       "L 703 4441 \n",
       "L 1819 4666 \n",
       "L 2450 4666 \n",
       "L 2450 531 \n",
       "L 3481 531 \n",
       "L 3481 0 \n",
       "L 794 0 \n",
       "L 794 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "        <path id=\"DejaVuSans-30\" d=\"M 2034 4250 \n",
       "Q 1547 4250 1301 3770 \n",
       "Q 1056 3291 1056 2328 \n",
       "Q 1056 1369 1301 889 \n",
       "Q 1547 409 2034 409 \n",
       "Q 2525 409 2770 889 \n",
       "Q 3016 1369 3016 2328 \n",
       "Q 3016 3291 2770 3770 \n",
       "Q 2525 4250 2034 4250 \n",
       "z\n",
       "M 2034 4750 \n",
       "Q 2819 4750 3233 4129 \n",
       "Q 3647 3509 3647 2328 \n",
       "Q 3647 1150 3233 529 \n",
       "Q 2819 -91 2034 -91 \n",
       "Q 1250 -91 836 529 \n",
       "Q 422 1150 422 2328 \n",
       "Q 422 3509 836 4129 \n",
       "Q 1250 4750 2034 4750 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <path d=\"M 116.331696 143.1 \n",
       "L 116.331696 7.2 \n",
       "\" clip-path=\"url(#p0c61bccd78)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m496c82adcf\" x=\"116.331696\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 200 -->\n",
       "      <g transform=\"translate(106.787946 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-32\" d=\"M 1228 531 \n",
       "L 3431 531 \n",
       "L 3431 0 \n",
       "L 469 0 \n",
       "L 469 531 \n",
       "Q 828 903 1448 1529 \n",
       "Q 2069 2156 2228 2338 \n",
       "Q 2531 2678 2651 2914 \n",
       "Q 2772 3150 2772 3378 \n",
       "Q 2772 3750 2511 3984 \n",
       "Q 2250 4219 1831 4219 \n",
       "Q 1534 4219 1204 4116 \n",
       "Q 875 4013 500 3803 \n",
       "L 500 4441 \n",
       "Q 881 4594 1212 4672 \n",
       "Q 1544 4750 1819 4750 \n",
       "Q 2544 4750 2975 4387 \n",
       "Q 3406 4025 3406 3419 \n",
       "Q 3406 3131 3298 2873 \n",
       "Q 3191 2616 2906 2266 \n",
       "Q 2828 2175 2409 1742 \n",
       "Q 1991 1309 1228 531 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-32\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <path d=\"M 156.188839 143.1 \n",
       "L 156.188839 7.2 \n",
       "\" clip-path=\"url(#p0c61bccd78)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m496c82adcf\" x=\"156.188839\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 300 -->\n",
       "      <g transform=\"translate(146.645089 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-33\" d=\"M 2597 2516 \n",
       "Q 3050 2419 3304 2112 \n",
       "Q 3559 1806 3559 1356 \n",
       "Q 3559 666 3084 287 \n",
       "Q 2609 -91 1734 -91 \n",
       "Q 1441 -91 1130 -33 \n",
       "Q 819 25 488 141 \n",
       "L 488 750 \n",
       "Q 750 597 1062 519 \n",
       "Q 1375 441 1716 441 \n",
       "Q 2309 441 2620 675 \n",
       "Q 2931 909 2931 1356 \n",
       "Q 2931 1769 2642 2001 \n",
       "Q 2353 2234 1838 2234 \n",
       "L 1294 2234 \n",
       "L 1294 2753 \n",
       "L 1863 2753 \n",
       "Q 2328 2753 2575 2939 \n",
       "Q 2822 3125 2822 3475 \n",
       "Q 2822 3834 2567 4026 \n",
       "Q 2313 4219 1838 4219 \n",
       "Q 1578 4219 1281 4162 \n",
       "Q 984 4106 628 3988 \n",
       "L 628 4550 \n",
       "Q 988 4650 1302 4700 \n",
       "Q 1616 4750 1894 4750 \n",
       "Q 2613 4750 3031 4423 \n",
       "Q 3450 4097 3450 3541 \n",
       "Q 3450 3153 3228 2886 \n",
       "Q 3006 2619 2597 2516 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-33\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_4\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <path d=\"M 196.045982 143.1 \n",
       "L 196.045982 7.2 \n",
       "\" clip-path=\"url(#p0c61bccd78)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m496c82adcf\" x=\"196.045982\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 400 -->\n",
       "      <g transform=\"translate(186.502232 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-34\" d=\"M 2419 4116 \n",
       "L 825 1625 \n",
       "L 2419 1625 \n",
       "L 2419 4116 \n",
       "z\n",
       "M 2253 4666 \n",
       "L 3047 4666 \n",
       "L 3047 1625 \n",
       "L 3713 1625 \n",
       "L 3713 1100 \n",
       "L 3047 1100 \n",
       "L 3047 0 \n",
       "L 2419 0 \n",
       "L 2419 1100 \n",
       "L 313 1100 \n",
       "L 313 1709 \n",
       "L 2253 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-34\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_5\">\n",
       "     <g id=\"line2d_9\">\n",
       "      <path d=\"M 235.903125 143.1 \n",
       "L 235.903125 7.2 \n",
       "\" clip-path=\"url(#p0c61bccd78)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_10\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m496c82adcf\" x=\"235.903125\" y=\"143.1\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 500 -->\n",
       "      <g transform=\"translate(226.359375 157.698438)scale(0.1 -0.1)\">\n",
       "       <defs>\n",
       "        <path id=\"DejaVuSans-35\" d=\"M 691 4666 \n",
       "L 3169 4666 \n",
       "L 3169 4134 \n",
       "L 1269 4134 \n",
       "L 1269 2991 \n",
       "Q 1406 3038 1543 3061 \n",
       "Q 1681 3084 1819 3084 \n",
       "Q 2600 3084 3056 2656 \n",
       "Q 3513 2228 3513 1497 \n",
       "Q 3513 744 3044 326 \n",
       "Q 2575 -91 1722 -91 \n",
       "Q 1428 -91 1123 -41 \n",
       "Q 819 9 494 109 \n",
       "L 494 744 \n",
       "Q 775 591 1075 516 \n",
       "Q 1375 441 1709 441 \n",
       "Q 2250 441 2565 725 \n",
       "Q 2881 1009 2881 1497 \n",
       "Q 2881 1984 2565 2268 \n",
       "Q 2250 2553 1709 2553 \n",
       "Q 1456 2553 1204 2497 \n",
       "Q 953 2441 691 2322 \n",
       "L 691 4666 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       </defs>\n",
       "       <use xlink:href=\"#DejaVuSans-35\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"127.246094\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- epoch -->\n",
       "     <g transform=\"translate(123.025 171.376563)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-65\" d=\"M 3597 1894 \n",
       "L 3597 1613 \n",
       "L 953 1613 \n",
       "Q 991 1019 1311 708 \n",
       "Q 1631 397 2203 397 \n",
       "Q 2534 397 2845 478 \n",
       "Q 3156 559 3463 722 \n",
       "L 3463 178 \n",
       "Q 3153 47 2828 -22 \n",
       "Q 2503 -91 2169 -91 \n",
       "Q 1331 -91 842 396 \n",
       "Q 353 884 353 1716 \n",
       "Q 353 2575 817 3079 \n",
       "Q 1281 3584 2069 3584 \n",
       "Q 2775 3584 3186 3129 \n",
       "Q 3597 2675 3597 1894 \n",
       "z\n",
       "M 3022 2063 \n",
       "Q 3016 2534 2758 2815 \n",
       "Q 2500 3097 2075 3097 \n",
       "Q 1594 3097 1305 2825 \n",
       "Q 1016 2553 972 2059 \n",
       "L 3022 2063 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-70\" d=\"M 1159 525 \n",
       "L 1159 -1331 \n",
       "L 581 -1331 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2969 \n",
       "Q 1341 3281 1617 3432 \n",
       "Q 1894 3584 2278 3584 \n",
       "Q 2916 3584 3314 3078 \n",
       "Q 3713 2572 3713 1747 \n",
       "Q 3713 922 3314 415 \n",
       "Q 2916 -91 2278 -91 \n",
       "Q 1894 -91 1617 61 \n",
       "Q 1341 213 1159 525 \n",
       "z\n",
       "M 3116 1747 \n",
       "Q 3116 2381 2855 2742 \n",
       "Q 2594 3103 2138 3103 \n",
       "Q 1681 3103 1420 2742 \n",
       "Q 1159 2381 1159 1747 \n",
       "Q 1159 1113 1420 752 \n",
       "Q 1681 391 2138 391 \n",
       "Q 2594 391 2855 752 \n",
       "Q 3116 1113 3116 1747 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6f\" d=\"M 1959 3097 \n",
       "Q 1497 3097 1228 2736 \n",
       "Q 959 2375 959 1747 \n",
       "Q 959 1119 1226 758 \n",
       "Q 1494 397 1959 397 \n",
       "Q 2419 397 2687 759 \n",
       "Q 2956 1122 2956 1747 \n",
       "Q 2956 2369 2687 2733 \n",
       "Q 2419 3097 1959 3097 \n",
       "z\n",
       "M 1959 3584 \n",
       "Q 2709 3584 3137 3096 \n",
       "Q 3566 2609 3566 1747 \n",
       "Q 3566 888 3137 398 \n",
       "Q 2709 -91 1959 -91 \n",
       "Q 1206 -91 779 398 \n",
       "Q 353 888 353 1747 \n",
       "Q 353 2609 779 3096 \n",
       "Q 1206 3584 1959 3584 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-63\" d=\"M 3122 3366 \n",
       "L 3122 2828 \n",
       "Q 2878 2963 2633 3030 \n",
       "Q 2388 3097 2138 3097 \n",
       "Q 1578 3097 1268 2742 \n",
       "Q 959 2388 959 1747 \n",
       "Q 959 1106 1268 751 \n",
       "Q 1578 397 2138 397 \n",
       "Q 2388 397 2633 464 \n",
       "Q 2878 531 3122 666 \n",
       "L 3122 134 \n",
       "Q 2881 22 2623 -34 \n",
       "Q 2366 -91 2075 -91 \n",
       "Q 1284 -91 818 406 \n",
       "Q 353 903 353 1747 \n",
       "Q 353 2603 823 3093 \n",
       "Q 1294 3584 2113 3584 \n",
       "Q 2378 3584 2631 3529 \n",
       "Q 2884 3475 3122 3366 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-68\" d=\"M 3513 2113 \n",
       "L 3513 0 \n",
       "L 2938 0 \n",
       "L 2938 2094 \n",
       "Q 2938 2591 2744 2837 \n",
       "Q 2550 3084 2163 3084 \n",
       "Q 1697 3084 1428 2787 \n",
       "Q 1159 2491 1159 1978 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 4863 \n",
       "L 1159 4863 \n",
       "L 1159 2956 \n",
       "Q 1366 3272 1645 3428 \n",
       "Q 1925 3584 2291 3584 \n",
       "Q 2894 3584 3203 3211 \n",
       "Q 3513 2838 3513 2113 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-65\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-70\" x=\"61.523438\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6f\" x=\"125\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-63\" x=\"186.181641\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-68\" x=\"241.162109\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_11\">\n",
       "      <path d=\"M 40.603125 103.461216 \n",
       "L 235.903125 103.461216 \n",
       "\" clip-path=\"url(#p0c61bccd78)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_12\">\n",
       "      <defs>\n",
       "       <path id=\"m108f3f10aa\" d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use xlink:href=\"#m108f3f10aa\" x=\"40.603125\" y=\"103.461216\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 5 -->\n",
       "      <g transform=\"translate(27.240625 107.260435)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_13\">\n",
       "      <path d=\"M 40.603125 61.252441 \n",
       "L 235.903125 61.252441 \n",
       "\" clip-path=\"url(#p0c61bccd78)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_14\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m108f3f10aa\" x=\"40.603125\" y=\"61.252441\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 10 -->\n",
       "      <g transform=\"translate(20.878125 65.05166)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-30\" x=\"63.623047\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_15\">\n",
       "      <path d=\"M 40.603125 19.043666 \n",
       "L 235.903125 19.043666 \n",
       "\" clip-path=\"url(#p0c61bccd78)\" style=\"fill: none; stroke: #b0b0b0; stroke-width: 0.8; stroke-linecap: square\"/>\n",
       "     </g>\n",
       "     <g id=\"line2d_16\">\n",
       "      <g>\n",
       "       <use xlink:href=\"#m108f3f10aa\" x=\"40.603125\" y=\"19.043666\" style=\"stroke: #000000; stroke-width: 0.8\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 15 -->\n",
       "      <g transform=\"translate(20.878125 22.842885)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use xlink:href=\"#DejaVuSans-35\" x=\"63.623047\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_10\">\n",
       "     <!-- perplexity -->\n",
       "     <g transform=\"translate(14.798437 100.276562)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-72\" d=\"M 2631 2963 \n",
       "Q 2534 3019 2420 3045 \n",
       "Q 2306 3072 2169 3072 \n",
       "Q 1681 3072 1420 2755 \n",
       "Q 1159 2438 1159 1844 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1341 3275 1631 3429 \n",
       "Q 1922 3584 2338 3584 \n",
       "Q 2397 3584 2469 3576 \n",
       "Q 2541 3569 2628 3553 \n",
       "L 2631 2963 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6c\" d=\"M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-78\" d=\"M 3513 3500 \n",
       "L 2247 1797 \n",
       "L 3578 0 \n",
       "L 2900 0 \n",
       "L 1881 1375 \n",
       "L 863 0 \n",
       "L 184 0 \n",
       "L 1544 1831 \n",
       "L 300 3500 \n",
       "L 978 3500 \n",
       "L 1906 2253 \n",
       "L 2834 3500 \n",
       "L 3513 3500 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-69\" d=\"M 603 3500 \n",
       "L 1178 3500 \n",
       "L 1178 0 \n",
       "L 603 0 \n",
       "L 603 3500 \n",
       "z\n",
       "M 603 4863 \n",
       "L 1178 4863 \n",
       "L 1178 4134 \n",
       "L 603 4134 \n",
       "L 603 4863 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-74\" d=\"M 1172 4494 \n",
       "L 1172 3500 \n",
       "L 2356 3500 \n",
       "L 2356 3053 \n",
       "L 1172 3053 \n",
       "L 1172 1153 \n",
       "Q 1172 725 1289 603 \n",
       "Q 1406 481 1766 481 \n",
       "L 2356 481 \n",
       "L 2356 0 \n",
       "L 1766 0 \n",
       "Q 1100 0 847 248 \n",
       "Q 594 497 594 1153 \n",
       "L 594 3053 \n",
       "L 172 3053 \n",
       "L 172 3500 \n",
       "L 594 3500 \n",
       "L 594 4494 \n",
       "L 1172 4494 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-79\" d=\"M 2059 -325 \n",
       "Q 1816 -950 1584 -1140 \n",
       "Q 1353 -1331 966 -1331 \n",
       "L 506 -1331 \n",
       "L 506 -850 \n",
       "L 844 -850 \n",
       "Q 1081 -850 1212 -737 \n",
       "Q 1344 -625 1503 -206 \n",
       "L 1606 56 \n",
       "L 191 3500 \n",
       "L 800 3500 \n",
       "L 1894 763 \n",
       "L 2988 3500 \n",
       "L 3597 3500 \n",
       "L 2059 -325 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-70\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"63.476562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"125\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-70\" x=\"166.113281\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6c\" x=\"229.589844\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-65\" x=\"257.373047\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-78\" x=\"317.146484\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"376.326172\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-74\" x=\"404.109375\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-79\" x=\"443.318359\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"line2d_17\">\n",
       "    <path d=\"M 40.603125 13.377273 \n",
       "L 44.588839 43.014056 \n",
       "L 48.574554 54.439401 \n",
       "L 52.560268 60.75052 \n",
       "L 56.545982 64.158181 \n",
       "L 60.531696 69.803517 \n",
       "L 64.517411 71.744403 \n",
       "L 68.503125 76.269325 \n",
       "L 72.488839 76.588283 \n",
       "L 76.474554 81.440421 \n",
       "L 80.460268 83.259675 \n",
       "L 84.445982 86.099556 \n",
       "L 88.431696 89.240547 \n",
       "L 92.417411 92.098371 \n",
       "L 96.403125 95.373352 \n",
       "L 100.388839 97.512292 \n",
       "L 104.374554 102.63836 \n",
       "L 108.360268 106.679603 \n",
       "L 112.345982 109.436896 \n",
       "L 116.331696 114.488164 \n",
       "L 120.317411 118.418205 \n",
       "L 124.303125 122.487526 \n",
       "L 128.288839 125.567831 \n",
       "L 132.274554 129.010677 \n",
       "L 136.260268 132.111659 \n",
       "L 140.245982 134.083598 \n",
       "L 144.231696 134.983747 \n",
       "L 148.217411 135.346774 \n",
       "L 152.203125 135.984109 \n",
       "L 156.188839 136.279846 \n",
       "L 160.174554 136.150569 \n",
       "L 164.160268 136.459062 \n",
       "L 168.145982 136.533881 \n",
       "L 172.131696 136.580315 \n",
       "L 176.117411 136.649642 \n",
       "L 180.103125 136.641676 \n",
       "L 184.088839 136.397526 \n",
       "L 188.074554 136.701189 \n",
       "L 192.060268 136.812116 \n",
       "L 196.045982 136.801627 \n",
       "L 200.031696 136.831185 \n",
       "L 204.017411 136.812126 \n",
       "L 208.003125 136.521114 \n",
       "L 211.988839 136.816221 \n",
       "L 215.974554 136.736076 \n",
       "L 219.960268 136.812743 \n",
       "L 223.945982 136.922727 \n",
       "L 227.931696 136.837235 \n",
       "L 231.917411 136.910329 \n",
       "L 235.903125 136.921131 \n",
       "\" clip-path=\"url(#p0c61bccd78)\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 40.603125 143.1 \n",
       "L 40.603125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 235.903125 143.1 \n",
       "L 235.903125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 40.603125 143.1 \n",
       "L 235.903125 143.1 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 40.603125 7.2 \n",
       "L 235.903125 7.2 \n",
       "\" style=\"fill: none; stroke: #000000; stroke-width: 0.8; stroke-linejoin: miter; stroke-linecap: square\"/>\n",
       "   </g>\n",
       "   <g id=\"legend_1\">\n",
       "    <g id=\"patch_7\">\n",
       "     <path d=\"M 173.628125 29.878125 \n",
       "L 228.903125 29.878125 \n",
       "Q 230.903125 29.878125 230.903125 27.878125 \n",
       "L 230.903125 14.2 \n",
       "Q 230.903125 12.2 228.903125 12.2 \n",
       "L 173.628125 12.2 \n",
       "Q 171.628125 12.2 171.628125 14.2 \n",
       "L 171.628125 27.878125 \n",
       "Q 171.628125 29.878125 173.628125 29.878125 \n",
       "z\n",
       "\" style=\"fill: #ffffff; opacity: 0.8; stroke: #cccccc; stroke-linejoin: miter\"/>\n",
       "    </g>\n",
       "    <g id=\"line2d_18\">\n",
       "     <path d=\"M 175.628125 20.298437 \n",
       "L 185.628125 20.298437 \n",
       "L 195.628125 20.298437 \n",
       "\" style=\"fill: none; stroke: #1f77b4; stroke-width: 1.5; stroke-linecap: square\"/>\n",
       "    </g>\n",
       "    <g id=\"text_11\">\n",
       "     <!-- train -->\n",
       "     <g transform=\"translate(203.628125 23.798437)scale(0.1 -0.1)\">\n",
       "      <defs>\n",
       "       <path id=\"DejaVuSans-61\" d=\"M 2194 1759 \n",
       "Q 1497 1759 1228 1600 \n",
       "Q 959 1441 959 1056 \n",
       "Q 959 750 1161 570 \n",
       "Q 1363 391 1709 391 \n",
       "Q 2188 391 2477 730 \n",
       "Q 2766 1069 2766 1631 \n",
       "L 2766 1759 \n",
       "L 2194 1759 \n",
       "z\n",
       "M 3341 1997 \n",
       "L 3341 0 \n",
       "L 2766 0 \n",
       "L 2766 531 \n",
       "Q 2569 213 2275 61 \n",
       "Q 1981 -91 1556 -91 \n",
       "Q 1019 -91 701 211 \n",
       "Q 384 513 384 1019 \n",
       "Q 384 1609 779 1909 \n",
       "Q 1175 2209 1959 2209 \n",
       "L 2766 2209 \n",
       "L 2766 2266 \n",
       "Q 2766 2663 2505 2880 \n",
       "Q 2244 3097 1772 3097 \n",
       "Q 1472 3097 1187 3025 \n",
       "Q 903 2953 641 2809 \n",
       "L 641 3341 \n",
       "Q 956 3463 1253 3523 \n",
       "Q 1550 3584 1831 3584 \n",
       "Q 2591 3584 2966 3190 \n",
       "Q 3341 2797 3341 1997 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "       <path id=\"DejaVuSans-6e\" d=\"M 3513 2113 \n",
       "L 3513 0 \n",
       "L 2938 0 \n",
       "L 2938 2094 \n",
       "Q 2938 2591 2744 2837 \n",
       "Q 2550 3084 2163 3084 \n",
       "Q 1697 3084 1428 2787 \n",
       "Q 1159 2491 1159 1978 \n",
       "L 1159 0 \n",
       "L 581 0 \n",
       "L 581 3500 \n",
       "L 1159 3500 \n",
       "L 1159 2956 \n",
       "Q 1366 3272 1645 3428 \n",
       "Q 1925 3584 2291 3584 \n",
       "Q 2894 3584 3203 3211 \n",
       "Q 3513 2838 3513 2113 \n",
       "z\n",
       "\" transform=\"scale(0.015625)\"/>\n",
       "      </defs>\n",
       "      <use xlink:href=\"#DejaVuSans-74\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-72\" x=\"39.208984\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-61\" x=\"80.322266\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-69\" x=\"141.601562\"/>\n",
       "      <use xlink:href=\"#DejaVuSans-6e\" x=\"169.384766\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"p0c61bccd78\">\n",
       "   <rect x=\"40.603125\" y=\"7.2\" width=\"195.3\" height=\"135.9\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "num_inputs = vocab_size\n",
    "gru_layer = nn.GRU(num_inputs, num_hiddens)\n",
    "model = d2l.RNNModel(gru_layer, len(vocab))\n",
    "model = model.to(device)\n",
    "d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f4a971d",
   "metadata": {
    "origin_pos": 29
   },
   "source": [
    "## 小结\n",
    "\n",
    "* 门控循环神经网络可以更好地捕获时间步距离很长的序列上的依赖关系。\n",
    "* 重置门有助于捕获序列中的短期依赖关系。\n",
    "* 更新门有助于捕获序列中的长期依赖关系。\n",
    "* 重置门打开时，门控循环单元包含基本循环神经网络；更新门打开时，门控循环单元可以跳过子序列。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 假设我们只想使用时间步$t'$的输入来预测时间步$t > t'$的输出。对于每个时间步，重置门和更新门的最佳值是什么？\n",
    "1. 调整和分析超参数对运行时间、困惑度和输出顺序的影响。\n",
    "1. 比较`rnn.RNN`和`rnn.GRU`的不同实现对运行时间、困惑度和输出字符串的影响。\n",
    "1. 如果仅仅实现门控循环单元的一部分，例如，只有一个重置门或一个更新门会怎样？\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbedd41b",
   "metadata": {
    "origin_pos": 31,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/2763)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}